package mosdi.matching;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import mosdi.util.ArrayUtils;
import mosdi.util.BitArray;

/**
 * A position weight matrix (PWM) over an integer-encoded alphabet.
 *
 * <p>The matrix is indexed as {@code matrix[character][column]}. Besides the
 * matrix itself, the smallest and largest entry of every column is kept so
 * that score bounds can be computed cheaply.
 */
public class PositionWeightMatrix {
	private final int alphabetSize;
	private final int columns;
	private final double[][] matrix;
	// column-specific minimum and maximum score
	private final double[] minScores;
	private final double[] maxScores;

	/**
	 * Creates a PWM directly from a score matrix.
	 *
	 * @param pwm scores indexed as {@code pwm[character][column]}; the array is
	 *            stored by reference (not copied) and must contain at least one row
	 */
	public PositionWeightMatrix(double[][] pwm) {
		this.matrix = pwm;
		alphabetSize = matrix.length;
		columns = matrix[0].length;
		minScores = new double[columns];
		maxScores = new double[columns];
		computeColumnExtremes();
	}
	
	/**
	 * Creates a log-odds PWM from a position frequency matrix.
	 *
	 * <p>Each entry becomes
	 * {@code log(((count + pseudoCounts) / (columnSum + alphabetSize * pseudoCounts)) / background[c])}.
	 *
	 * @param pfm                   counts indexed as {@code pfm[character][column]}
	 * @param characterDistribution background character weights; passed through
	 *                              {@code ArrayUtils.normalizedCopy} before use
	 *                              (presumably rescaled to sum to one)
	 * @param pseudoCounts          value added to every count
	 */
	public PositionWeightMatrix(double[][] pfm, double[] characterDistribution, double pseudoCounts) {
		double[] background = ArrayUtils.normalizedCopy(characterDistribution);
		alphabetSize = pfm.length;
		columns = pfm[0].length;
		double[] columnSums = new double[columns];
		for (int c=0; c<alphabetSize; ++c) {
			for (int column=0; column<columns; ++column) columnSums[column] += pfm[c][column];
		}
		matrix = new double[alphabetSize][columns];
		for (int c=0; c<alphabetSize; ++c) {
			for (int column=0; column<columns; ++column) {
				double frequency = (pfm[c][column]+pseudoCounts) / (columnSums[column]+alphabetSize*pseudoCounts);
				matrix[c][column] = Math.log(frequency / background[c]);
			}
		}
		minScores = new double[columns];
		maxScores = new double[columns];
		computeColumnExtremes();
	}

	/** Fills minScores/maxScores with the smallest/largest entry of each matrix column. */
	private void computeColumnExtremes() {
		Arrays.fill(minScores, Double.POSITIVE_INFINITY);
		Arrays.fill(maxScores, Double.NEGATIVE_INFINITY);
		for (int c=0; c<alphabetSize; ++c) {
			for (int column=0; column<columns; ++column) {
				minScores[column] = Math.min(minScores[column], matrix[c][column]);
				maxScores[column] = Math.max(maxScores[column], matrix[c][column]);
			}
		}
	}

	/** Return the number of matrix columns. */
	public int width() {
		return columns;
	}
	
	/** Minimum possible score of a sequence of same width as PWM. */
	public double minScore() {
		return ArrayUtils.sum(minScores);
	}

	/** Maximum possible score of a sequence of same width as PWM. */
	public double maxScore() {
		return ArrayUtils.sum(maxScores);
	}

	/** Returns the smallest position>=startPosition such that the substring starting
	 *  there has a score >= threshold. Positions whose window contains the
	 *  character -1 are skipped. Returns -1 if no such position exists. */
	public int nextMatch(int[] string, int startPosition, double threshold) {
		for (int i=startPosition; i<string.length-columns+1; ++i) {
			double score = score(string,i);
			if (Double.isNaN(score)) continue;
			if (score>=threshold) return i;
		}
		return -1;
	}
	
	/** Returns the maximal score over all positions in the given string. Positions
	 *  whose window contains the character -1 are skipped; if no position can be
	 *  scored, Double.NEGATIVE_INFINITY is returned. */
	public double maxScore(int[] string) {
		double maxScore = Double.NEGATIVE_INFINITY;
		for (int i=0; i<string.length-columns+1; ++i) {
			double score = score(string,i);
			if (Double.isNaN(score)) continue;
			maxScore = Math.max(maxScore, score);
		}
		return maxScore;
	}

	/** Comparator to sort column indices by descending difference between max and min score. */
	private class ColumnIndexComparator implements Comparator<Integer> {
		@Override
		public int compare(Integer o1, Integer o2) {
			double diff1 = maxScores[o1] - minScores[o1];
			double diff2 = maxScores[o2] - minScores[o2];
			// Arguments swapped to obtain descending order.
			return Double.compare(diff2, diff1);
		}
	}
	
	/**
	 * Enumerates generalized strings by branching over the columns in order of
	 * decreasing score range and pruning with bounds on the undecided columns.
	 */
	private class GeneralizedStringGenerator {
		private List<int[]> results;
		private final Map<BitArray, Integer> characterMap;
		// maps each plain character to the index of its singleton set
		private final int[] alphabetTranslator;
		// index of the generalized character that contains all plain characters
		private final int wildcard;
		// column indices in processing order
		private Integer[] order;
		// indexed by column: largest/smallest total score obtainable from the
		// columns that come AFTER that column in the processing order
		private double[] cumulativeMaxScores;
		private double[] cumulativeMinScores;

		/**
		 * @param characterMap maps character sets to their index in the generalized
		 *                     alphabet; must contain the full set and all singletons
		 * @throws IllegalArgumentException if one of those sets is missing
		 */
		private GeneralizedStringGenerator(Map<BitArray, Integer> characterMap) {
			this.characterMap = characterMap;
			BitArray all = new BitArray(alphabetSize);
			all.invert();
			this.wildcard = lookup(all);
			alphabetTranslator = new int[alphabetSize];
			for (int c=0; c<alphabetSize; ++c) {
				BitArray b = new BitArray(alphabetSize);
				b.set(c, true);
				alphabetTranslator[c] = lookup(b);
			}
		}

		/** Returns the generalized character index of the given set, failing clearly if it is absent. */
		private int lookup(BitArray characterSet) {
			Integer index = characterMap.get(characterSet);
			if (index == null) {
				throw new IllegalArgumentException("Generalized alphabet lacks character set: " + characterSet);
			}
			return index;
		}

		/** Returns all generalized strings for the given score threshold. */
		private List<int[]> generate(double threshold) {
			order = new Integer[columns];
			for (int i=0; i<columns; ++i) order[i] = i;
			Arrays.sort(order, new ColumnIndexComparator());
			cumulativeMinScores = new double[columns];
			cumulativeMaxScores = new double[columns];
			// The last column in the order has nothing after it and keeps bound 0.
			// Every other column gets the extremes of its successor plus the
			// successor's own bound, i.e. the column itself is NOT included.
			for (int i=columns-2; i>=0; --i) {
				int col = order[i];
				int nextCol = order[i+1];
				cumulativeMinScores[col] = minScores[nextCol] + cumulativeMinScores[nextCol];
				cumulativeMaxScores[col] = maxScores[nextCol] + cumulativeMaxScores[nextCol];
			}
			results = new ArrayList<int[]>();
			int[] current = new int[columns];
			Arrays.fill(current, wildcard);
			generateRecursively(current, 0, threshold);
			return results;
		}

		/**
		 * Decides the column at the given depth of the processing order.
		 *
		 * @param current   partial string; undecided columns hold the wildcard
		 * @param depth     index into the processing order
		 * @param threshold score still required from this and all later columns
		 */
		private void generateRecursively(int[] current, int depth, double threshold) {
			if (depth==columns) {
				// For non-NaN scores this is not reached: the last column has zero
				// look-ahead bounds, so each character is either pruned or sufficient.
				results.add(current);
				return;
			}
			int column = order[depth];
			BitArray sufficientChars = new BitArray(alphabetSize);
			for (int c=0; c<alphabetSize; ++c) {
				// Even the best continuation cannot reach the threshold: prune.
				if (matrix[c][column]+cumulativeMaxScores[column]<threshold) continue;
				if (matrix[c][column]+cumulativeMinScores[column]>=threshold) {
					// Even the worst continuation reaches the threshold.
					sufficientChars.set(c, true);
				} else {
					// Undecided: fix this character and branch on the next column.
					int[] s = Arrays.copyOf(current, current.length);
					s[column] = alphabetTranslator[c];
					generateRecursively(s, depth+1, threshold-matrix[c][column]);
				}
			}
			if (!sufficientChars.allZero()) {
				// All sufficient characters share one string; later columns stay wildcards.
				// Mutating current is safe here because the branches above got copies.
				current[column] = lookup(sufficientChars);
				results.add(current);
			}
		}
	}
	
	/**
	 * Returns a list of IUPAC strings that match exactly those strings with a
	 * PWM score >= threshold.
	 *
	 * @param threshold           minimum score a matched string must have
	 * @param generalizedAlphabet character sets of the generalized alphabet; the
	 *                            returned strings contain indices into this array
	 * @throws IllegalArgumentException if the alphabet lacks a character set that
	 *                                  is needed to express the result
	 */
	public List<int[]> toGeneralizedStrings(double threshold, BitArray[] generalizedAlphabet) {
		Map<BitArray, Integer> characterMap = new HashMap<BitArray, Integer>();
		for (int i=0; i<generalizedAlphabet.length; ++i) {
			characterMap.put(generalizedAlphabet[i], i);
		}
		GeneralizedStringGenerator gsg = new GeneralizedStringGenerator(characterMap);
		return gsg.generate(threshold);
	}
	
	/** Returns the matrix entry for the given character at the given column. */
	public double getEntry(int character, int position) {
		return matrix[character][position];
	}
	
	/** Renders the matrix with one line per character and tab-separated columns. */
	@Override
	public String toString() {
		StringBuilder sb = new StringBuilder();
		for (int c=0; c<alphabetSize; ++c) {
			for (int column=0; column<columns; ++column) {
				if (column>0) sb.append('\t');
				sb.append(matrix[c][column]);
			}
			sb.append('\n');
		}
		return sb.toString();
	}
	
	/**
	 * Scores the window of PWM width that starts at startPosition.
	 *
	 * @return the sum of the matrix entries selected by the window's characters,
	 *         or Double.NaN if the character -1 is encountered in the window
	 */
	public double score(int[] string, int startPosition) {
		double score = 0.0d;
		for (int i=0; i<columns; ++i) {
			int c = string[startPosition+i];
			if (c==-1) return Double.NaN;
			score += matrix[c][i];
		}
		return score;
	}

	/**
	 * Scores a string whose length equals the PWM width.
	 *
	 * @throws IllegalArgumentException if the length differs from the PWM width
	 */
	public double score(int[] string) {
		if (string.length!=columns) throw new IllegalArgumentException("Length mismatch.");
		return score(string,0);
	}
	
}
